{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a9226af6-32fa-4c3c-966a-e54099fcd09d",
   "metadata": {},
   "source": [
    "# **03:** Function Calling\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b6a555d-75d1-4241-bee3-a4ffaab4ecd2",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "# Pinned below 1.0 because this notebook uses the legacy `MistralClient` API.\n",
    "# !pip install pandas \"mistralai>=0.1.2,<1.0\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "359eb942-9210-4f8b-a37c-54d343ce5a89",
   "metadata": {},
   "source": [
    "### Load API key"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb19dc9d-b07d-4bb2-a8e2-001b88d4f09a",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "from helper import load_mistral_api_key\n",
    "api_key, dlai_endpoint = load_mistral_api_key(ret_key=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ea4f4d1-ddb3-4db4-b9ee-020e563de00b",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b4c0408-94a8-4c3f-b518-253b718f436e",
   "metadata": {
    "height": 251
   },
   "outputs": [],
   "source": [
    "data = {\n",
    "    \"transaction_id\": [\"T1001\", \"T1002\", \"T1003\", \"T1004\", \"T1005\"],\n",
    "    \"customer_id\": [\"C001\", \"C002\", \"C003\", \"C002\", \"C001\"],\n",
    "    \"payment_amount\": [125.50, 89.99, 120.00, 54.30, 210.20],\n",
    "    \"payment_date\": [\n",
    "        \"2021-10-05\",\n",
    "        \"2021-10-06\",\n",
    "        \"2021-10-07\",\n",
    "        \"2021-10-05\",\n",
    "        \"2021-10-08\",\n",
    "    ],\n",
    "    \"payment_status\": [\"Paid\", \"Unpaid\", \"Paid\", \"Paid\", \"Pending\"],\n",
    "}\n",
    "df = pd.DataFrame(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26c215e9-d2b0-4188-bd54-0d406f47d5ce",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7fb00b24-9b3c-4a86-80e6-810c61ae650b",
   "metadata": {},
   "source": [
    "#### How you might answer data questions without function calling\n",
    "- This approach is not recommended; it is shown only as a contrast to function calling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22c81d83-1409-484d-9b69-47c0923cdbbc",
   "metadata": {
    "height": 421
   },
   "outputs": [],
   "source": [
    "data = \"\"\"{\n",
    "    \"transaction_id\": [\"T1001\", \"T1002\", \"T1003\", \"T1004\", \"T1005\"],\n",
    "    \"customer_id\": [\"C001\", \"C002\", \"C003\", \"C002\", \"C001\"],\n",
    "    \"payment_amount\": [125.50, 89.99, 120.00, 54.30, 210.20],\n",
    "    \"payment_date\": [\n",
    "        \"2021-10-05\",\n",
    "        \"2021-10-06\",\n",
    "        \"2021-10-07\",\n",
    "        \"2021-10-05\",\n",
    "        \"2021-10-08\",\n",
    "    ],\n",
    "    \"payment_status\": [\"Paid\", \"Unpaid\", \"Paid\", \"Paid\", \"Pending\"],\n",
    "}\n",
    "\"\"\"\n",
    "transaction_id = \"T1001\"\n",
    "\n",
    "prompt = f\"\"\"\n",
    "Given the following data, what is the payment status for \\\n",
    " transaction_id={transaction_id}?\n",
    "\n",
    "data:\n",
    "{data}\n",
    "\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12fe08ce-a95e-4736-bfa0-872a4d129675",
   "metadata": {
    "height": 302
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from mistralai.client import MistralClient\n",
    "from mistralai.models.chat_completion import ChatMessage\n",
    "\n",
    "\n",
    "def mistral(user_message, model=\"mistral-small-latest\", is_json=False):\n",
    "    client = MistralClient(api_key=api_key, endpoint=dlai_endpoint)\n",
    "    messages = [ChatMessage(role=\"user\", content=user_message)]\n",
    "\n",
    "    if is_json:\n",
    "        chat_response = client.chat(\n",
    "            model=model, messages=messages, response_format={\"type\": \"json_object\"}\n",
    "        )\n",
    "    else:\n",
    "        chat_response = client.chat(model=model, messages=messages)\n",
    "\n",
    "    return chat_response.choices[0].message.content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a6a80fc-c05c-4ebf-a2f5-46f1ae0218ce",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "response = mistral(prompt)\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e0ef041-a23a-437a-b36a-58a1b69be5db",
   "metadata": {},
   "source": [
    "## Step 1. User: specify tools and query\n",
    "\n",
    "### Tools\n",
    "\n",
    "- You can define all tools that you might want the model to call."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5520443c-62a8-482f-acfb-8ae9b554e0fc",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f2663fc-80ff-4fd8-a503-eca2f295499d",
   "metadata": {
    "height": 115
   },
   "outputs": [],
   "source": [
    "def retrieve_payment_status(df: pd.DataFrame, transaction_id: str) -> str:\n",
    "    \"\"\"Return the payment status of a transaction as a JSON string.\"\"\"\n",
    "    if transaction_id in df.transaction_id.values:\n",
    "        return json.dumps(\n",
    "            {\"status\": df[df.transaction_id == transaction_id].payment_status.item()}\n",
    "        )\n",
    "    # Unknown id: return the error as JSON so the model can relay it.\n",
    "    return json.dumps({\"error\": \"transaction id not found.\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8aa0c617-c62d-481c-a1e0-bc5b4e385558",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "status = retrieve_payment_status(df, transaction_id=\"T1001\")\n",
    "print(status)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39dc3c22-4c71-4aba-a2b0-3ae7cfdf33e5",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "type(status)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9e153c1-6ff6-4878-a8c2-e02bac749210",
   "metadata": {
    "height": 115
   },
   "outputs": [],
   "source": [
    "def retrieve_payment_date(df: pd.DataFrame, transaction_id: str) -> str:\n",
    "    \"\"\"Return the payment date of a transaction as a JSON string.\"\"\"\n",
    "    if transaction_id in df.transaction_id.values:\n",
    "        return json.dumps(\n",
    "            {\"date\": df[df.transaction_id == transaction_id].payment_date.item()}\n",
    "        )\n",
    "    # Unknown id: return the error as JSON so the model can relay it.\n",
    "    return json.dumps({\"error\": \"transaction id not found.\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39a3cc95-426b-4459-9ac5-8ce04a48c339",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "date = retrieve_payment_date(df, transaction_id=\"T1002\")\n",
    "print(date)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5728a885-ec7c-4d21-b461-ebb5e29b0a1e",
   "metadata": {},
   "source": [
    "- Describe each function with a JSON schema so that the model knows its name, purpose, and parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33b5a13e-b985-4a2f-b4bd-6115eb51c74b",
   "metadata": {
    "height": 302
   },
   "outputs": [],
   "source": [
    "tool_payment_status = {\n",
    "    \"type\": \"function\",\n",
    "    \"function\": {\n",
    "        \"name\": \"retrieve_payment_status\",\n",
    "        \"description\": \"Get payment status of a transaction\",\n",
    "        \"parameters\": {\n",
    "            \"type\": \"object\",\n",
    "            \"properties\": {\n",
    "                \"transaction_id\": {\n",
    "                    \"type\": \"string\",\n",
    "                    \"description\": \"The transaction id.\",\n",
    "                }\n",
    "            },\n",
    "            \"required\": [\"transaction_id\"],\n",
    "        },\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50921cc3-bcd3-4027-b246-4946e59acd91",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "type(tool_payment_status)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4974d7f7-4dbe-4bcc-994c-97af22401049",
   "metadata": {
    "height": 302
   },
   "outputs": [],
   "source": [
    "tool_payment_date = {\n",
    "    \"type\": \"function\",\n",
    "    \"function\": {\n",
    "        \"name\": \"retrieve_payment_date\",\n",
    "        \"description\": \"Get payment date of a transaction\",\n",
    "        \"parameters\": {\n",
    "            \"type\": \"object\",\n",
    "            \"properties\": {\n",
    "                \"transaction_id\": {\n",
    "                    \"type\": \"string\",\n",
    "                    \"description\": \"The transaction id.\",\n",
    "                }\n",
    "            },\n",
    "            \"required\": [\"transaction_id\"],\n",
    "        },\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c39808b5-6fc2-47fd-a8aa-66c33e9c1271",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "type(tool_payment_date)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "756006e4-ae7e-4023-89d5-7de970f2efe7",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "tools = [tool_payment_status, tool_payment_date]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6e67822-8c7f-4b1b-a568-d29dc111a86c",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "type(tools)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "217ebbe3-10a1-4ca0-9bce-e694c6007e56",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "tools"
   ]
  },
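  {
   "cell_type": "markdown",
   "id": "sketch-make-tool-spec",
   "metadata": {},
   "source": [
    "- The two tool specifications above differ only in their name and description. As a minimal sketch (the helper `make_tool_spec` is hypothetical and not part of the `mistralai` package), a small factory function could build them without repeating the schema:\n",
    "\n",
    "```python\n",
    "def make_tool_spec(name, description):\n",
    "    \"\"\"Build a tool spec with one required string argument, `transaction_id`.\"\"\"\n",
    "    return {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": name,\n",
    "            \"description\": description,\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"transaction_id\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The transaction id.\",\n",
    "                    }\n",
    "                },\n",
    "                \"required\": [\"transaction_id\"],\n",
    "            },\n",
    "        },\n",
    "    }\n",
    "\n",
    "\n",
    "spec = make_tool_spec(\"retrieve_payment_status\", \"Get payment status of a transaction\")\n",
    "print(spec[\"function\"][\"name\"])\n",
    "```\n",
    "\n",
    "- This assumes every tool takes exactly one required string argument, `transaction_id`; tools with other parameters would need their own schema."
   ]
  },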
  {
   "cell_type": "markdown",
   "id": "bd7d253d-f20c-4f09-a09b-258757ebb8c3",
   "metadata": {},
   "source": [
    "### functools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "363cf49c-c6dc-48e3-a857-eaae5745c11d",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "import functools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88d3fda1-72bb-4ee3-a80f-340716a22a01",
   "metadata": {
    "height": 81
   },
   "outputs": [],
   "source": [
    "names_to_functions = {\n",
    "    \"retrieve_payment_status\": functools.partial(retrieve_payment_status, df=df),\n",
    "    \"retrieve_payment_date\": functools.partial(retrieve_payment_date, df=df),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acffb815-ae8f-45cb-94aa-e657eda1e8f9",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "names_to_functions[\"retrieve_payment_status\"](transaction_id=\"T1001\")"
   ]
  },
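  {
   "cell_type": "markdown",
   "id": "sketch-functools-partial",
   "metadata": {},
   "source": [
    "- `functools.partial` pre-fills some of a function's arguments and returns a new callable. The sketch below is a self-contained illustration of that idea; `lookup_status` and the `records` dictionary are hypothetical stand-ins for `retrieve_payment_status` and `df`:\n",
    "\n",
    "```python\n",
    "import functools\n",
    "import json\n",
    "\n",
    "# Hypothetical stand-in for the DataFrame used in this notebook.\n",
    "records = {\"T1001\": \"Paid\", \"T1002\": \"Unpaid\"}\n",
    "\n",
    "\n",
    "def lookup_status(records, transaction_id):\n",
    "    # Same contract as the notebook's tools: always return a JSON string.\n",
    "    if transaction_id in records:\n",
    "        return json.dumps({\"status\": records[transaction_id]})\n",
    "    return json.dumps({\"error\": \"transaction id not found.\"})\n",
    "\n",
    "\n",
    "# Bind `records` once; the resulting callable only needs `transaction_id`,\n",
    "# which is the one argument the model supplies.\n",
    "lookup = functools.partial(lookup_status, records=records)\n",
    "\n",
    "print(lookup(transaction_id=\"T1001\"))\n",
    "```\n",
    "\n",
    "- Binding the data up front keeps it out of the tool schema, so the model never has to produce it."
   ]
  },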
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08e3d1bd-97f1-4186-a648-7611a4fecca6",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "tools"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0a05bf6-c244-4124-8196-f9a780fef95d",
   "metadata": {},
   "source": [
    "### User query\n",
    "\n",
    "- Example: “What’s the status of my transaction?”"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a9a9813-5889-4dad-bd4b-107f7f9dc00a",
   "metadata": {
    "height": 98
   },
   "outputs": [],
   "source": [
    "from mistralai.models.chat_completion import ChatMessage\n",
    "\n",
    "chat_history = [\n",
    "    ChatMessage(role=\"user\", content=\"What's the status of my transaction?\")\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d926683-3cc1-4de6-ad53-adefe8d5cc0b",
   "metadata": {},
   "source": [
    "## Step 2. Model: Generate function arguments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fa5241d-1ca2-49b6-b0be-b74716dc9b76",
   "metadata": {
    "height": 200
   },
   "outputs": [],
   "source": [
    "from mistralai.client import MistralClient\n",
    "\n",
    "model = \"mistral-large-latest\"\n",
    "\n",
    "# Reuse the key and endpoint loaded at the top of the notebook.\n",
    "client = MistralClient(api_key=api_key, endpoint=dlai_endpoint)\n",
    "\n",
    "response = client.chat(\n",
    "    model=model, messages=chat_history, tools=tools, tool_choice=\"auto\"\n",
    ")\n",
    "\n",
    "response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "687dcf79-ceea-4828-8c42-943989ace5e3",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "response.choices[0].message.content"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de9cfb80-6054-4d4f-863b-53209b4aea79",
   "metadata": {},
   "source": [
    "### Save the chat history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3328d0cd-01df-4cbe-9619-d76fa3cbbbaf",
   "metadata": {
    "height": 98
   },
   "outputs": [],
   "source": [
    "chat_history.append(\n",
    "    ChatMessage(role=\"assistant\", content=response.choices[0].message.content)\n",
    ")\n",
    "chat_history.append(ChatMessage(role=\"user\", content=\"My transaction ID is T1001.\"))\n",
    "chat_history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18162db2-315c-4f28-a417-7b5a2691c526",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "response = client.chat(\n",
    "    model=model, messages=chat_history, tools=tools, tool_choice=\"auto\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "277e0819-6ac3-4ff1-9c1d-c5f623f36915",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4411e4cc-1733-4109-b73a-08c656d7a003",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "response.choices[0].message"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0376db72-0d92-4808-ae42-da52f6544649",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "chat_history.append(response.choices[0].message)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78e66180-79b1-47aa-a4a9-dc8e855cc8b3",
   "metadata": {},
   "source": [
    "- Notice these fields in the tool call returned by the model:\n",
    "  - `name='retrieve_payment_status'`\n",
    "  - `arguments='{\"transaction_id\": \"T1001\"}'`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ffab43e6-3225-45b4-bc85-3689c2eee7c3",
   "metadata": {},
   "source": [
    "## Step 3. User: Execute function to obtain tool results\n",
    "\n",
    "- The model only proposes the function call; it does not execute anything on its own. You, the user, run the function and pass the result back."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f392bae9-6b79-4571-a773-7425b56edca4",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "tool_function = response.choices[0].message.tool_calls[0].function\n",
    "print(tool_function)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59728fb2-f1ed-4d93-b077-ec88b75723a8",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "tool_function.name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5c4b4e8-58b8-4835-8ba1-e90c90cc5bd2",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "tool_function.arguments"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5605fb6-d17f-4033-87ad-5f1ad626f8b1",
   "metadata": {},
   "source": [
    "- `tool_function.arguments` is a JSON string, but the function expects its arguments as a Python dictionary.\n",
    "- To convert this string into a dictionary, you can use `json.loads()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2002b583-21c7-4d03-b3d5-8146bf90cfad",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "args = json.loads(tool_function.arguments)\n",
    "print(args)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28c6b861-3a5f-4edc-a63e-1a1c778ab94a",
   "metadata": {},
   "source": [
    "- Recall the `names_to_functions` dictionary that you built earlier with `functools.partial`:\n",
    "\n",
    "```Python\n",
    "import functools\n",
    "names_to_functions = {\n",
    "    \"retrieve_payment_status\": \n",
    "      functools.partial(retrieve_payment_status, df=df),\n",
    "    \n",
    "    \"retrieve_payment_date\": \n",
    "      functools.partial(retrieve_payment_date, df=df),\n",
    "}\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2d01ace-4fa2-4e3b-877e-d3e00b36b924",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "function_result = names_to_functions[tool_function.name](**args)\n",
    "function_result"
   ]
  },
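  {
   "cell_type": "markdown",
   "id": "sketch-safe-tool-dispatch",
   "metadata": {},
   "source": [
    "- A model can occasionally name a tool that does not exist or produce arguments that are not valid JSON. The sketch below shows one possible defensive wrapper around the dispatch step; `run_tool_call`, `registry`, and the toy `echo_id` tool are hypothetical names used only for illustration:\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "\n",
    "def run_tool_call(name, arguments, registry):\n",
    "    \"\"\"Look up a tool by name and call it with JSON-encoded arguments.\"\"\"\n",
    "    if name not in registry:\n",
    "        return json.dumps({\"error\": f\"unknown tool: {name}\"})\n",
    "    try:\n",
    "        kwargs = json.loads(arguments)\n",
    "    except json.JSONDecodeError:\n",
    "        return json.dumps({\"error\": \"arguments are not valid JSON.\"})\n",
    "    if not isinstance(kwargs, dict):\n",
    "        return json.dumps({\"error\": \"arguments must be a JSON object.\"})\n",
    "    return registry[name](**kwargs)\n",
    "\n",
    "\n",
    "# Hypothetical registry with a single toy tool.\n",
    "registry = {\"echo_id\": lambda transaction_id: json.dumps({\"id\": transaction_id})}\n",
    "\n",
    "print(run_tool_call(\"echo_id\", '{\"transaction_id\": \"T1001\"}', registry))\n",
    "print(run_tool_call(\"missing_tool\", \"{}\", registry))\n",
    "```\n",
    "\n",
    "- Returning errors as JSON strings, like the tools in this notebook do, lets the failure be passed back to the model as an ordinary tool result."
   ]
  },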
  {
   "cell_type": "markdown",
   "id": "a63345b0-4ad6-4047-badc-3feb45202d46",
   "metadata": {},
   "source": [
    "- The output of the function call can be saved as a chat message, with the role \"tool\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81a14939-b6a9-45ba-92cb-b190969f6889",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "tool_msg = ChatMessage(role=\"tool\", name=tool_function.name, content=function_result)\n",
    "chat_history.append(tool_msg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2dfe4ca1-e32a-4749-80b0-f7db953ead31",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "chat_history"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2aa2e194-1e97-4446-817d-9b57dcfab371",
   "metadata": {},
   "source": [
    "## Step 4. Model: Generate final answer\n",
    "- The model can now reply to the user query, given the information provided by the \"tool\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe284553-87cc-4a8d-a42f-13a16325691b",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "response = client.chat(model=model, messages=chat_history)\n",
    "response.choices[0].message.content"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e308a098-704c-48f5-9577-444f381c6feb",
   "metadata": {},
   "source": [
    "### Try it for yourself!\n",
    "- Try asking another question about the data, such as \"How much did I pay for my recent order?\"\n",
    "  - You can use transaction ID \"T1002\", for instance.\n",
    "  - Neither existing tool returns `payment_amount`, so this question needs a new tool."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db6954bb-becb-44c1-b3be-ac8d53e3861c",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
